import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
from statsmodels.stats.weightstats import ztest

# Raw exports: survey 1 (taken without the education step) and survey 2
# (taken with it), read in that order.
df, df2 = (
    pd.read_csv(csv_path)
    for csv_path in ("New Audience Survey.csv", "2nd Batch Gov 1430 Final.csv")
)

# Correct label for each question, in order: answer_key[0] is Q1, answer_key[9] is Q10.
answer_key = [
    "Human Captured", "AI-Generated", "Human Captured", "AI-Generated", "AI-Generated",
    "Human Captured", "AI-Generated", "Human Captured", "AI-Generated", "Human Captured",
]
# 1-based numbers of the questions whose image is AI-generated. Derived from
# answer_key (rather than typed out separately) so the two can never disagree.
ai_gen = [qnum for qnum, label in enumerate(answer_key, start=1) if label == "AI-Generated"]

def compare_answers (guesses, key):
    """Return the fraction of *key* that *guesses* matches, position by position.

    guesses: the respondent's answers in question order -- a list, or a pandas
        Series such as one row's question columns. Values are iterated
        directly, so a Series' index labels (e.g. "Q1".."Q10") are ignored.
    key: the correct answers in the same order; its length is the denominator,
        so unanswered or missing guesses count as wrong.

    Raises ValueError if *key* is empty (accuracy would be undefined).
    """
    if len(key) == 0:
        raise ValueError("answer key must not be empty")
    # zip pairs values positionally. The previous guesses[i] form relied on
    # pandas' deprecated integer-position fallback for label-indexed Series.
    correct = sum(guess == truth for guess, truth in zip(guesses, key))
    return correct / len(key)

# Rename columns. Both survey exports share one layout, so the header list is
# built once (metadata, the ten question columns, then demographics) and
# applied to each frame.
questions = ["Q1", "Q2", "Q3", "Q4", "Q5", "Q6", "Q7", "Q8", "Q9", "Q10"]
survey_columns = [
    "Respondent ID", "Collector ID", "Start Date", "End Date", "IP Address",
    "Email Address", "First Name", "Last Name", "Custom Data 1",
    "collector_type_source",
    *questions,
    "Age", "Device Type", "Gender", "Household Income", "Region",
]
df.columns = survey_columns
df2.columns = survey_columns

# reset_index() keeps the original row number in an "index" column; then the
# extra row at label 0 (holding only the word "Response") is removed.
df = df.reset_index().drop([0])
df2 = df2.reset_index().drop([0])

# Dump the cleaned frames for inspection.
df.to_csv("test2.csv")
df2.to_csv("test3.csv")

# Give every respondent an "accuracy" score: compare_answers applied to that
# row's ten question columns against answer_key. Survey 1 first, then survey 2.
for survey_frame in (df, df2):
    survey_frame["accuracy"] = survey_frame.apply(
        lambda row: compare_answers(row[questions], answer_key),
        axis=1,
    )

# Per-question accuracy for each survey, plus pooled accuracy on AI-generated
# vs. human-captured images. In ai_acc / human_acc, index 0 is survey 1
# (before education) and index 1 is survey 2 (after education).
n_ai = len(ai_gen)
n_human = len(questions) - n_ai
s1count = len(df.index)
s2count = len(df2.index)

acc_s1 = []
acc_s2 = []
ai_acc = [0, 0]
human_acc = [0, 0]
for qnum, question in enumerate(questions, start=1):
    if qnum in ai_gen:
        ans = "AI-Generated"
        totals = ai_acc
    else:
        ans = "Human Captured"
        totals = human_acc
    # Count each question's correct answers once and reuse the counts.
    correct_s1 = (df[question] == ans).sum()
    correct_s2 = (df2[question] == ans).sum()
    totals[0] += correct_s1
    totals[1] += correct_s2
    acc_s1.append(correct_s1)
    acc_s2.append(correct_s2)

# Counts -> proportions. The pooled denominators use the real number of AI and
# human questions instead of a hard-coded 5, so they stay right if the key changes.
acc_s1 = [count / s1count for count in acc_s1]
acc_s2 = [count / s2count for count in acc_s2]
ai_acc[0] = ai_acc[0] / (s1count * n_ai)
ai_acc[1] = ai_acc[1] / (s2count * n_ai)
human_acc[0] = human_acc[0] / (s1count * n_human)
human_acc[1] = human_acc[1] / (s2count * n_human)
print("AI-Generated Accuracy % Before & After Education")
print(ai_acc)
print("Human-Captured Accuracy % Before & After Education")
print(human_acc)

print("Survey 1 accuracy by question")
print(acc_s1)
print("Survey 2 accuracy by question")
print(acc_s2)

# Re-dump survey 1, now including the "accuracy" column.
df.to_csv("test2.csv")

# Calculating overall stats: for each question, print the share of each answer
# in survey 1 (no education), then in survey 2 (with education), then the
# correct answer.
for qnum, question in enumerate(questions, start=1):
    val_noedu = df[question].value_counts(normalize=True)
    val_edu = df2[question].value_counts(normalize=True)
    ans = "AI-Generated" if qnum in ai_gen else "Human Captured"
    print(val_noedu)
    print(val_edu)
    print(ans)

# ANOVA for age

def anova_calc(df, age_groups=("18-29", "30-44", "45-60", "> 60")):
    """Run a one-way ANOVA of respondent accuracy across age groups.

    Prints each group's accuracy values and the scipy result, and also
    returns the result so callers can use the statistic / p-value.

    df: survey frame with an "Age" column and a numeric "accuracy" column.
    age_groups: the "Age" values to compare. Groups with no respondents are
        left out of the test, because an empty sample makes scipy's f_oneway
        return NaN for the whole test (recent SciPy versions also warn).

    Returns the scipy.stats.f_oneway result, or None when fewer than two
    groups have respondents (f_oneway requires at least two samples).
    """
    anova_df = df[["Age", "accuracy"]]
    # Accuracy values of the respondents in each age group, keyed by group.
    accuracy_by_group = {
        group: anova_df.loc[anova_df["Age"] == group, "accuracy"]
        for group in age_groups
    }
    print(accuracy_by_group)
    samples = [values for values in accuracy_by_group.values() if len(values) > 0]
    if len(samples) < 2:
        print("ANOVA needs at least two non-empty age groups; skipping.")
        return None
    result = stats.f_oneway(*samples)
    print(result)
    return result

print("ANOVA calculations on age groups: ")
for survey_frame in (df, df2):
    anova_calc(survey_frame)

# Overall mean accuracy: survey 1, then survey 2.
for survey_frame in (df, df2):
    print(survey_frame[["accuracy"]].agg("mean"))

# Visualizations on By Question

def _finish_question_chart(title):
    """Apply the per-question chart title and axis labels, then display the figure."""
    plt.title(title)
    plt.xlabel("Question")
    plt.ylabel("Accuracy")
    plt.show()

# Survey 1 alone.
plt.figure()
plt.bar(questions, acc_s1)
_finish_question_chart("Accuracy By Question")

# Survey 2 drawn as narrower bars on top of survey 1.
plt.figure()
plt.bar(questions, acc_s1, label="Without Education")
plt.bar(questions, acc_s2, width=0.4, label="With Education")
plt.legend(loc="upper left")
_finish_question_chart("Accuracy By Question: With & Without Education")

# DEMOGRAPHIC ANALYSIS

def _mean_accuracy_by(frame, column):
    """Return a one-column DataFrame of mean "accuracy" per distinct value of *column*."""
    return frame.groupby(column)[["accuracy"]].agg("mean")

age_group = _mean_accuracy_by(df, "Age")
print(age_group)

age_group2 = _mean_accuracy_by(df2, "Age")
print(age_group2)

income_group = _mean_accuracy_by(df, "Household Income")
print(income_group)
income_group2 = _mean_accuracy_by(df2, "Household Income")
print(income_group2)
# Nothing meaningful here - subgroups likely too small

region_group = _mean_accuracy_by(df, "Region")
print(region_group)
region_group2 = _mean_accuracy_by(df2, "Region")
print(region_group2)
# Nothing meaningful here - subgroups likely too small

# Make bar graphs

def _finish_age_chart():
    """Apply the age chart title and axis labels, then display the figure."""
    plt.title("AI Identification Accuracy by Age")
    plt.xlabel("age")
    plt.ylabel("accuracy")
    plt.show()

# Survey 1 alone.
plt.figure()
plt.bar(age_group.index, age_group["accuracy"])
_finish_age_chart()

# Survey 2 drawn as narrower bars on top of survey 1.
plt.figure()
plt.bar(age_group.index, age_group["accuracy"], label="Without Education")
plt.bar(age_group2.index, age_group2["accuracy"], width=0.4, label="With Education")
plt.legend(loc="upper right")
_finish_age_chart()

# Not used - region bar graphs

# Survey 1 alone; small tick labels because region names are long.
plt.figure()
ax = plt.axes()
plt.bar(region_group.index, region_group["accuracy"])
plt.title("AI Identification Accuracy by Region of US")
plt.xlabel("region")
plt.ylabel("accuracy")
ax.tick_params(labelsize=6)
plt.show()

# Survey 2 drawn as narrower bars on top of survey 1. Both series are labelled
# and a legend is shown, as in the other overlay charts, so they can be told apart.
plt.figure()
ax = plt.axes()
plt.bar(region_group.index, region_group["accuracy"], label="Without Education")
plt.bar(region_group2.index, region_group2["accuracy"], width=0.4, label="With Education")
plt.legend(loc="upper right")
plt.title("AI Identification Accuracy by Region of US")
plt.xlabel("region")
plt.ylabel("accuracy")
ax.tick_params(labelsize=6)
plt.show()

# AI VS Human Image Accuracies - both surveys
# One stacked chart per survey: the correct share is the lower bar segment and
# the incorrect share (1 - correct) is stacked on top of it.
image_kinds = ["AI-Generated", "Human-Captured"]
chart_titles = ("Overall Accuracy for Survey 1", "Overall Accuracy for Survey 2")
for chart_title, ai_share, human_share in zip(chart_titles, ai_acc, human_acc):
    correct_shares = [ai_share, human_share]
    wrong_shares = [1 - share for share in correct_shares]
    plt.figure()
    plt.bar(image_kinds, correct_shares, label="Correct")
    plt.bar(image_kinds, wrong_shares, bottom=correct_shares, label="Incorrect")
    plt.legend(loc="upper left")
    plt.title(chart_title)
    plt.ylabel("accuracy")
    plt.show()